Feed-forward neural network

This is a simple tutorial on how to train a feed-forward neural network to predict protein subcellular localization.


In [1]:
# Import all the necessary modules
import os
os.environ["THEANO_FLAGS"] = "mode=FAST_RUN,optimizer=None,device=cpu,floatX=float32"
import sys
sys.path.insert(0,'..')
import numpy as np
import theano
import theano.tensor as T
import lasagne
from confusionmatrix import ConfusionMatrix
from utils import iterate_minibatches
import matplotlib.pyplot as plt
import time
import itertools
%matplotlib inline

Building the network

The first thing we have to do is define the network architecture. Here we are going to use an input layer, a dense layer, and an output layer. These are the steps that we are going to follow:

1.- Specify the hyperparameters of the network:


In [2]:
batch_size = 128
seq_len = 400
n_feat = 20
n_hid = 30
n_class = 10
lr = 0.0025
drop_prob = 0.5

2.- Define the input variables to our network:


In [3]:
# We use ftensor3 because the protein data is a 3D matrix of float32 values
input_var = T.ftensor3('inputs')
# ivector because the labels form a one-dimensional vector of integers
target_var = T.ivector('targets')
# Dummy data to check the layer output shapes while building the network
X = np.random.randint(0,10,size=(batch_size,seq_len,n_feat)).astype('float32')
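
As a brief aside (illustrative only, not part of the network), Theano variables are symbolic: they carry no data until an expression built from them is evaluated. This is the same eval pattern used to print the layer shapes below:

# Illustrative only: symbolic variables receive concrete values at evaluation time
a = T.fscalar('a')
b = a * 2
print(b.eval({a: np.float32(3.0)}))  # 6.0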

3.- Define the layers of the network:


In [4]:
# Input layer, holds the shape of the data
l_in = lasagne.layers.InputLayer(shape=(batch_size, seq_len, n_feat), input_var=input_var, name='Input')
print('Input layer: {}'.format(
    lasagne.layers.get_output(l_in, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Dense layer with ReLU activation function
l_dense = lasagne.layers.DenseLayer(l_in, num_units=n_hid, name="Dense",
                                    nonlinearity=lasagne.nonlinearities.rectify)
print('Dense layer: {}'.format(
    lasagne.layers.get_output(l_dense, inputs={l_in: input_var}).eval({input_var: X}).shape))

# Output layer with a Softmax activation function
l_out = lasagne.layers.DenseLayer(lasagne.layers.dropout(l_dense, p=drop_prob), num_units=n_class, 
                                  name="Softmax", nonlinearity=lasagne.nonlinearities.softmax)
print('Output layer: {}'.format(
    lasagne.layers.get_output(l_out, inputs={l_in: input_var}).eval({input_var: X}).shape))


Input layer: (128, 400, 20)
Dense layer: (128, 30)
Output layer: (128, 10)
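
Note that the dense layer flattens all trailing dimensions of its input, so each (400, 20) sequence is treated as a single vector of 400 × 20 = 8000 features. A minimal NumPy sketch of the forward computation, with W_h and b_h as hypothetical stand-ins for the layer's learned parameters:

# Sketch of the dense layer's computation on the dummy batch X
W_h = np.random.randn(seq_len * n_feat, n_hid).astype('float32')  # hypothetical weights
b_h = np.zeros(n_hid, dtype='float32')                            # hypothetical biases
h = np.maximum(0, X.reshape(batch_size, -1).dot(W_h) + b_h)       # ReLU(x.W + b)
print(h.shape)  # (128, 30), matching the Dense layer above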

4.- Calculate the prediction and network loss for the training set and update the network weights:


In [5]:
# Get the network output for training; deterministic=False keeps dropout active
prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var}, deterministic=False)

# Calculate the categorical cross entropy between the labels and the prediction
t_loss = T.nnet.categorical_crossentropy(prediction, target_var)

# Training loss
loss = T.mean(t_loss)

# Parameters
params = lasagne.layers.get_all_params([l_out], trainable=True)

# Get the network gradients and clip their total norm to 3
all_grads = lasagne.updates.total_norm_constraint(T.grad(loss, params), 3)

# Update parameters using Adam
updates = lasagne.updates.adam(all_grads, params, learning_rate=lr)
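
total_norm_constraint rescales all gradients by a common factor whenever their joint L2 norm exceeds the threshold (3 here), which keeps individual updates from blowing up. A minimal NumPy sketch of the idea (not Lasagne's actual implementation):

def clip_total_norm(grads, max_norm):
    # Global L2 norm over all gradient arrays
    total_norm = np.sqrt(sum((g ** 2).sum() for g in grads))
    # Rescale every gradient by the same factor if the norm is too large
    if total_norm > max_norm:
        grads = [g * (max_norm / total_norm) for g in grads]
    return grads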

5.- Calculate the prediction and network loss for the validation set:


In [6]:
# Get the network output for validation; deterministic=True disables dropout
val_prediction = lasagne.layers.get_output(l_out, inputs={l_in: input_var}, deterministic=True)

# Calculate the categorical cross entropy between the labels and the prediction
t_val_loss = lasagne.objectives.categorical_crossentropy(val_prediction, target_var)

# Validation loss 
val_loss = T.mean(t_val_loss)
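
The only difference from the training graph is deterministic=True, which turns dropout off. Lasagne's dropout layer uses inverted dropout, so no rescaling is needed at test time; a minimal sketch of the behaviour:

def dropout_sketch(x, p, deterministic):
    if deterministic:
        return x  # validation/test: identity
    mask = np.random.binomial(1, 1 - p, size=x.shape)
    return x * mask / (1 - p)  # rescale at training time (inverted dropout)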

6.- Build theano functions:


In [7]:
# Build functions
train_fn = theano.function([input_var, target_var], [loss, prediction], updates=updates)
val_fn = theano.function([input_var, target_var], [val_loss, val_prediction])
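
As an optional sanity check, the compiled functions can be called on the dummy batch from before together with random labels; using val_fn leaves the weights untouched:

# Optional smoke test: random labels paired with the dummy batch X
y_dummy = np.random.randint(0, n_class, size=batch_size).astype('int32')
l, p = val_fn(X, y_dummy)
print(l, p.shape)  # a scalar loss and (128, 10) class probabilities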

Load dataset

Once the network is built, the next step is to load the training and validation sets.


In [8]:
# Load the encoded protein sequences, labels and masks
# The masks are not needed for the FFN or CNN models
train = np.load('data/reduced_train.npz')
X_train = train['X_train']
y_train = train['y_train']
mask_train = train['mask_train']
print(X_train.shape)


(2423, 400, 20)

In [9]:
validation = np.load('data/reduced_val.npz')
X_val = validation['X_val']
y_val = validation['y_val']
mask_val = validation['mask_val']
print(X_val.shape)


(635, 400, 20)
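
The 20 features per position presumably encode the residue at that position over the 20 standard amino acids; the exact encoding depends on how the dataset was prepared. Under a simple one-hot assumption, the encoding would look like this (AMINO_ACIDS and one_hot_encode are hypothetical, for illustration only):

AMINO_ACIDS = 'ACDEFGHIKLMNPQRSTVWY'
def one_hot_encode(seq, max_len=400):
    # One row per position, one column per amino acid; zero-padded to max_len
    x = np.zeros((max_len, len(AMINO_ACIDS)), dtype='float32')
    for i, aa in enumerate(seq[:max_len]):
        x[i, AMINO_ACIDS.index(aa)] = 1.0
    return x
print(one_hot_encode('MKT').shape)  # (400, 20), one active feature per residue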

Training

Once the data is ready and the network is compiled, we can start training the model. Here we define the number of epochs (full passes over the training data) that we want to perform.


In [10]:
# Number of epochs
num_epochs = 80

# Lists to save loss and accuracy of each epoch
loss_training = []
loss_validation = []
acc_training = []
acc_validation = []
start_time = time.time()
min_val_loss = float("inf")

# Start training 
for epoch in range(num_epochs):
    
    # Full pass training set
    train_err = 0
    train_batches = 0
    confusion_train = ConfusionMatrix(n_class)

    # Generate minibatches and train on each one of them
    for batch in iterate_minibatches(X_train.astype(np.float32), y_train.astype(np.int32), 
                                     mask_train.astype(np.float32), batch_size, shuffle=True, sort_len=False):
        # Inputs to the network
        inputs, targets, in_masks = batch
        # Calculate loss and prediction
        tr_err, predict = train_fn(inputs, targets)
        train_err += tr_err
        train_batches += 1
        # Get the predicted class, the one with the highest predicted probability
        preds = np.argmax(predict, axis=-1)
        confusion_train.batch_add(targets, preds)
    
    # Average loss and accuracy
    train_loss = train_err / train_batches
    train_accuracy = confusion_train.accuracy()
    cf_train = confusion_train.ret_mat()

    val_err = 0
    val_batches = 0
    confusion_valid = ConfusionMatrix(n_class)

    # Generate minibatches and validate on each one of them, same procedure as before
    for batch in iterate_minibatches(X_val.astype(np.float32), y_val.astype(np.int32), 
                                     mask_val.astype(np.float32), batch_size, shuffle=True, sort_len=False):
        inputs, targets, in_masks = batch
        err, predict_val = val_fn(inputs, targets)
        val_err += err
        val_batches += 1
        preds = np.argmax(predict_val, axis=-1)
        confusion_valid.batch_add(targets, preds)

    val_loss = val_err / val_batches
    val_accuracy = confusion_valid.accuracy()
    cf_val = confusion_valid.ret_mat()
    
    loss_training.append(train_loss)
    loss_validation.append(val_loss)
    acc_training.append(train_accuracy)
    acc_validation.append(val_accuracy)
    
    # Save the model parameters at the epoch with the lowest validation loss
    if min_val_loss > val_loss:
        min_val_loss = val_loss
        np.savez('params/FFN_params.npz', *lasagne.layers.get_all_param_values(l_out))
    
    print("Epoch {} of {} time elapsed {:.3f}s".format(epoch + 1, num_epochs, time.time() - start_time))
    print("  training loss:\t\t{:.6f}".format(train_loss))
    print("  validation loss:\t\t{:.6f}".format(val_loss))
    print("  training accuracy:\t\t{:.2f} %".format(train_accuracy * 100))
    print("  validation accuracy:\t\t{:.2f} %".format(val_accuracy * 100))


Epoch 1 of 80 time elapsed 0.324s
  training loss:		2.107359
  validation loss:		1.859765
  training accuracy:		26.85 %
  validation accuracy:		36.72 %
Epoch 2 of 80 time elapsed 0.647s
  training loss:		1.743450
  validation loss:		1.591770
  training accuracy:		40.95 %
  validation accuracy:		47.97 %
Epoch 3 of 80 time elapsed 0.970s
  training loss:		1.472835
  validation loss:		1.461971
  training accuracy:		50.29 %
  validation accuracy:		56.09 %
Epoch 4 of 80 time elapsed 1.190s
  training loss:		1.282473
  validation loss:		1.300663
  training accuracy:		56.83 %
  validation accuracy:		60.94 %
Epoch 5 of 80 time elapsed 1.415s
  training loss:		1.134756
  validation loss:		1.199360
  training accuracy:		60.03 %
  validation accuracy:		61.56 %
Epoch 6 of 80 time elapsed 1.642s
  training loss:		0.987879
  validation loss:		1.088289
  training accuracy:		66.00 %
  validation accuracy:		67.19 %
Epoch 7 of 80 time elapsed 1.886s
  training loss:		0.900661
  validation loss:		1.047687
  training accuracy:		68.79 %
  validation accuracy:		65.62 %
Epoch 8 of 80 time elapsed 2.126s
  training loss:		0.820768
  validation loss:		0.982008
  training accuracy:		70.64 %
  validation accuracy:		70.47 %
Epoch 9 of 80 time elapsed 2.353s
  training loss:		0.743999
  validation loss:		0.969201
  training accuracy:		72.90 %
  validation accuracy:		70.31 %
Epoch 10 of 80 time elapsed 2.603s
  training loss:		0.671169
  validation loss:		0.925915
  training accuracy:		76.07 %
  validation accuracy:		71.56 %
Epoch 11 of 80 time elapsed 2.848s
  training loss:		0.615192
  validation loss:		0.884622
  training accuracy:		78.37 %
  validation accuracy:		73.28 %
Epoch 12 of 80 time elapsed 3.075s
  training loss:		0.575456
  validation loss:		0.867527
  training accuracy:		78.74 %
  validation accuracy:		74.06 %
Epoch 13 of 80 time elapsed 3.331s
  training loss:		0.512095
  validation loss:		0.864169
  training accuracy:		81.50 %
  validation accuracy:		73.75 %
Epoch 14 of 80 time elapsed 3.569s
  training loss:		0.494531
  validation loss:		0.849973
  training accuracy:		82.03 %
  validation accuracy:		74.06 %
Epoch 15 of 80 time elapsed 3.832s
  training loss:		0.466686
  validation loss:		0.862280
  training accuracy:		82.98 %
  validation accuracy:		72.81 %
Epoch 16 of 80 time elapsed 4.029s
  training loss:		0.434729
  validation loss:		0.852806
  training accuracy:		83.72 %
  validation accuracy:		74.22 %
Epoch 17 of 80 time elapsed 4.256s
  training loss:		0.418157
  validation loss:		0.851349
  training accuracy:		84.25 %
  validation accuracy:		74.22 %
Epoch 18 of 80 time elapsed 4.495s
  training loss:		0.417498
  validation loss:		0.829231
  training accuracy:		84.29 %
  validation accuracy:		75.47 %
Epoch 19 of 80 time elapsed 4.783s
  training loss:		0.395914
  validation loss:		0.869444
  training accuracy:		84.75 %
  validation accuracy:		73.75 %
Epoch 20 of 80 time elapsed 4.984s
  training loss:		0.359312
  validation loss:		0.839163
  training accuracy:		86.55 %
  validation accuracy:		75.47 %
Epoch 21 of 80 time elapsed 5.169s
  training loss:		0.362639
  validation loss:		0.859328
  training accuracy:		86.35 %
  validation accuracy:		74.84 %
Epoch 22 of 80 time elapsed 5.363s
  training loss:		0.351706
  validation loss:		0.861067
  training accuracy:		86.68 %
  validation accuracy:		75.00 %
Epoch 23 of 80 time elapsed 5.638s
  training loss:		0.343546
  validation loss:		0.863921
  training accuracy:		87.46 %
  validation accuracy:		75.78 %
Epoch 24 of 80 time elapsed 5.845s
  training loss:		0.332753
  validation loss:		0.861443
  training accuracy:		87.62 %
  validation accuracy:		75.62 %
Epoch 25 of 80 time elapsed 6.080s
  training loss:		0.315228
  validation loss:		0.904788
  training accuracy:		87.83 %
  validation accuracy:		72.66 %
Epoch 26 of 80 time elapsed 6.345s
  training loss:		0.295694
  validation loss:		0.926737
  training accuracy:		89.14 %
  validation accuracy:		75.31 %
Epoch 27 of 80 time elapsed 6.610s
  training loss:		0.289340
  validation loss:		0.885212
  training accuracy:		88.53 %
  validation accuracy:		74.84 %
Epoch 28 of 80 time elapsed 6.891s
  training loss:		0.305048
  validation loss:		0.888011
  training accuracy:		87.83 %
  validation accuracy:		75.16 %
Epoch 29 of 80 time elapsed 7.081s
  training loss:		0.283399
  validation loss:		0.863941
  training accuracy:		89.76 %
  validation accuracy:		75.94 %
Epoch 30 of 80 time elapsed 7.363s
  training loss:		0.290155
  validation loss:		0.909824
  training accuracy:		88.28 %
  validation accuracy:		76.09 %
Epoch 31 of 80 time elapsed 7.636s
  training loss:		0.257156
  validation loss:		0.963012
  training accuracy:		90.13 %
  validation accuracy:		73.59 %
Epoch 32 of 80 time elapsed 7.901s
  training loss:		0.254546
  validation loss:		0.969289
  training accuracy:		90.58 %
  validation accuracy:		73.91 %
Epoch 33 of 80 time elapsed 8.097s
  training loss:		0.268782
  validation loss:		0.937327
  training accuracy:		89.39 %
  validation accuracy:		75.31 %
Epoch 34 of 80 time elapsed 8.326s
  training loss:		0.245839
  validation loss:		0.969967
  training accuracy:		90.83 %
  validation accuracy:		74.53 %
Epoch 35 of 80 time elapsed 8.544s
  training loss:		0.242339
  validation loss:		1.002957
  training accuracy:		90.54 %
  validation accuracy:		73.75 %
Epoch 36 of 80 time elapsed 8.759s
  training loss:		0.229800
  validation loss:		0.975329
  training accuracy:		91.04 %
  validation accuracy:		74.22 %
Epoch 37 of 80 time elapsed 9.047s
  training loss:		0.232424
  validation loss:		0.981292
  training accuracy:		90.91 %
  validation accuracy:		75.47 %
Epoch 38 of 80 time elapsed 9.400s
  training loss:		0.215324
  validation loss:		0.956630
  training accuracy:		90.87 %
  validation accuracy:		75.62 %
Epoch 39 of 80 time elapsed 9.731s
  training loss:		0.221702
  validation loss:		1.013137
  training accuracy:		91.45 %
  validation accuracy:		75.00 %
Epoch 40 of 80 time elapsed 10.152s
  training loss:		0.216461
  validation loss:		0.996233
  training accuracy:		91.57 %
  validation accuracy:		75.94 %
Epoch 41 of 80 time elapsed 10.565s
  training loss:		0.232326
  validation loss:		1.015741
  training accuracy:		90.54 %
  validation accuracy:		75.47 %
Epoch 42 of 80 time elapsed 10.913s
  training loss:		0.201155
  validation loss:		1.080081
  training accuracy:		92.19 %
  validation accuracy:		75.00 %
Epoch 43 of 80 time elapsed 11.266s
  training loss:		0.201297
  validation loss:		1.054666
  training accuracy:		92.23 %
  validation accuracy:		75.31 %
Epoch 44 of 80 time elapsed 11.608s
  training loss:		0.195273
  validation loss:		1.029481
  training accuracy:		92.15 %
  validation accuracy:		75.00 %
Epoch 45 of 80 time elapsed 11.930s
  training loss:		0.185151
  validation loss:		1.049541
  training accuracy:		92.23 %
  validation accuracy:		75.62 %
Epoch 46 of 80 time elapsed 12.213s
  training loss:		0.192719
  validation loss:		0.998281
  training accuracy:		91.78 %
  validation accuracy:		76.09 %
Epoch 47 of 80 time elapsed 12.486s
  training loss:		0.176988
  validation loss:		1.047978
  training accuracy:		93.17 %
  validation accuracy:		75.78 %
Epoch 48 of 80 time elapsed 12.760s
  training loss:		0.192038
  validation loss:		1.149443
  training accuracy:		92.35 %
  validation accuracy:		75.31 %
Epoch 49 of 80 time elapsed 13.031s
  training loss:		0.192546
  validation loss:		1.049649
  training accuracy:		91.65 %
  validation accuracy:		76.09 %
Epoch 50 of 80 time elapsed 13.301s
  training loss:		0.190402
  validation loss:		1.087051
  training accuracy:		92.15 %
  validation accuracy:		76.41 %
Epoch 51 of 80 time elapsed 13.573s
  training loss:		0.172711
  validation loss:		1.063931
  training accuracy:		92.27 %
  validation accuracy:		75.31 %
Epoch 52 of 80 time elapsed 13.853s
  training loss:		0.183372
  validation loss:		1.082522
  training accuracy:		92.72 %
  validation accuracy:		75.47 %
Epoch 53 of 80 time elapsed 14.128s
  training loss:		0.189659
  validation loss:		1.071080
  training accuracy:		92.11 %
  validation accuracy:		76.41 %
Epoch 54 of 80 time elapsed 14.394s
  training loss:		0.175962
  validation loss:		1.136218
  training accuracy:		92.72 %
  validation accuracy:		75.00 %
Epoch 55 of 80 time elapsed 14.667s
  training loss:		0.169955
  validation loss:		1.123299
  training accuracy:		92.76 %
  validation accuracy:		76.09 %
Epoch 56 of 80 time elapsed 14.934s
  training loss:		0.172079
  validation loss:		1.111096
  training accuracy:		92.89 %
  validation accuracy:		75.47 %
Epoch 57 of 80 time elapsed 15.201s
  training loss:		0.182867
  validation loss:		1.086331
  training accuracy:		92.52 %
  validation accuracy:		75.94 %
Epoch 58 of 80 time elapsed 15.501s
  training loss:		0.186257
  validation loss:		1.180740
  training accuracy:		92.35 %
  validation accuracy:		75.62 %
Epoch 59 of 80 time elapsed 15.844s
  training loss:		0.166665
  validation loss:		1.195616
  training accuracy:		93.01 %
  validation accuracy:		75.47 %
Epoch 60 of 80 time elapsed 16.193s
  training loss:		0.166927
  validation loss:		1.172126
  training accuracy:		92.97 %
  validation accuracy:		75.31 %
Epoch 61 of 80 time elapsed 16.510s
  training loss:		0.153975
  validation loss:		1.120659
  training accuracy:		93.50 %
  validation accuracy:		76.41 %
Epoch 62 of 80 time elapsed 16.868s
  training loss:		0.172289
  validation loss:		1.136208
  training accuracy:		93.05 %
  validation accuracy:		75.16 %
Epoch 63 of 80 time elapsed 17.166s
  training loss:		0.154623
  validation loss:		1.163755
  training accuracy:		93.71 %
  validation accuracy:		74.53 %
Epoch 64 of 80 time elapsed 17.465s
  training loss:		0.165401
  validation loss:		1.187017
  training accuracy:		93.01 %
  validation accuracy:		75.31 %
Epoch 65 of 80 time elapsed 17.756s
  training loss:		0.150237
  validation loss:		1.197934
  training accuracy:		94.16 %
  validation accuracy:		75.47 %
Epoch 66 of 80 time elapsed 18.048s
  training loss:		0.159940
  validation loss:		1.408773
  training accuracy:		93.96 %
  validation accuracy:		75.31 %
Epoch 67 of 80 time elapsed 18.337s
  training loss:		0.174225
  validation loss:		1.164674
  training accuracy:		92.56 %
  validation accuracy:		75.16 %
Epoch 68 of 80 time elapsed 18.625s
  training loss:		0.158366
  validation loss:		1.217839
  training accuracy:		93.38 %
  validation accuracy:		75.47 %
Epoch 69 of 80 time elapsed 18.913s
  training loss:		0.147981
  validation loss:		1.317361
  training accuracy:		94.04 %
  validation accuracy:		74.69 %
Epoch 70 of 80 time elapsed 19.206s
  training loss:		0.163380
  validation loss:		1.231518
  training accuracy:		93.13 %
  validation accuracy:		75.78 %
Epoch 71 of 80 time elapsed 19.508s
  training loss:		0.149447
  validation loss:		1.280646
  training accuracy:		93.75 %
  validation accuracy:		75.31 %
Epoch 72 of 80 time elapsed 19.809s
  training loss:		0.156145
  validation loss:		1.239825
  training accuracy:		93.79 %
  validation accuracy:		76.41 %
Epoch 73 of 80 time elapsed 20.171s
  training loss:		0.155871
  validation loss:		1.199442
  training accuracy:		93.38 %
  validation accuracy:		75.47 %
Epoch 74 of 80 time elapsed 20.492s
  training loss:		0.150546
  validation loss:		1.254731
  training accuracy:		93.42 %
  validation accuracy:		75.62 %
Epoch 75 of 80 time elapsed 20.815s
  training loss:		0.151264
  validation loss:		1.292623
  training accuracy:		93.59 %
  validation accuracy:		75.47 %
Epoch 76 of 80 time elapsed 21.129s
  training loss:		0.150849
  validation loss:		1.229852
  training accuracy:		94.00 %
  validation accuracy:		74.69 %
Epoch 77 of 80 time elapsed 21.425s
  training loss:		0.142646
  validation loss:		1.221507
  training accuracy:		94.82 %
  validation accuracy:		76.09 %
Epoch 78 of 80 time elapsed 21.735s
  training loss:		0.138869
  validation loss:		1.258508
  training accuracy:		94.28 %
  validation accuracy:		75.94 %
Epoch 79 of 80 time elapsed 22.040s
  training loss:		0.141389
  validation loss:		1.314633
  training accuracy:		93.91 %
  validation accuracy:		75.16 %
Epoch 80 of 80 time elapsed 22.347s
  training loss:		0.146385
  validation loss:		1.308442
  training accuracy:		94.37 %
  validation accuracy:		76.25 %

In [11]:
print("Minimum validation loss: {:.6f}".format(min_val_loss))


Minimum validation loss: 0.829231
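
To reuse the best-performing model later, the parameters saved above can be restored into the network; a minimal sketch following Lasagne's standard save/load recipe:

# Restore the parameters saved at the epoch with the lowest validation loss
with np.load('params/FFN_params.npz') as f:
    param_values = [f['arr_%d' % i] for i in range(len(f.files))]
lasagne.layers.set_all_param_values(l_out, param_values)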

Model loss and accuracy

Here we plot the loss and the accuracy for the training and validation sets at each epoch.


In [12]:
x_axis = range(num_epochs)
plt.figure(figsize=(8,6))
plt.plot(x_axis,loss_training)
plt.plot(x_axis,loss_validation)
plt.xlabel('Epoch')
plt.ylabel('Error')
plt.legend(('Training','Validation'));



In [13]:
plt.figure(figsize=(8,6))
plt.plot(x_axis,acc_training)
plt.plot(x_axis,acc_validation)
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.legend(('Training','Validation'));


Confusion matrix

The confusion matrix allows us to visualize how well each class is predicted and which misclassifications are most common.
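
The ConfusionMatrix helper imported at the top accumulates counts of (true, predicted) pairs; a minimal sketch of the idea, assuming an equivalent batch_add interface:

# Sketch of confusion-matrix accumulation: rows are true classes,
# columns are predicted classes
cm = np.zeros((n_class, n_class), dtype=int)
def batch_add_sketch(cm, targets, preds):
    np.add.at(cm, (targets, preds), 1)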


In [14]:
# Plot confusion matrix 
# Code based on http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html

plt.figure(figsize=(8,8))
cmap = plt.cm.Blues
plt.imshow(cf_val, interpolation='nearest', cmap=cmap)
plt.title('Confusion matrix validation set')
plt.colorbar()
tick_marks = np.arange(n_class)
classes = ['Nucleus','Cytoplasm','Extracellular','Mitochondrion','Cell membrane','ER',
           'Chloroplast','Golgi apparatus','Lysosome','Vacuole']

plt.xticks(tick_marks, classes, rotation=60)
plt.yticks(tick_marks, classes)

thresh = cf_val.max() / 2.
for i, j in itertools.product(range(cf_val.shape[0]), range(cf_val.shape[1])):
    plt.text(j, i, cf_val[i, j],
             horizontalalignment="center",
             color="white" if cf_val[i, j] > thresh else "black")

plt.tight_layout()
plt.ylabel('True location')
plt.xlabel('Predicted location');


